import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(current_dir, '..')))
from datetime import datetime

import torch
from torch.autograd import Variable
import torch.nn as nn
import numpy as np
import wandb
from tqdm import tqdm
import copy

from evaluate.robust.attack_pgd import pgd
from vis.img_plt import plot_img
from train.mlbase import MLBase
from model.model_getter import get_model
from tool.args import get_general_args
from tool.util import init_wandb
from tool.logger import Logger
from evaluate.ood import ana_oodauc
from evaluate.ood import ana_logp_hist
from evaluate.calibration import calibration
from evaluate.landscape import plot_logit_landscape
from model.model_io import _save_model, save_on_master
from model.model_io import load_pretrained_weights

from data.dl_getter import get_dl_robust
from tool.util import get_valid_unit


class Evaluator(MLBase):
    """Evaluation driver built on top of an ``MLBase`` training object.

    Runs standard accuracy, OOD detection, calibration and adversarial
    robustness evaluations on ``self.model`` and collects the results in
    ``self.text_log`` (written to ``log_all.txt`` / ``log_table.txt`` under
    ``args.output_path`` by :meth:`write_log`).

    Two modes, selected by the ``epoch`` argument of :meth:`__call__`:

    * ``epoch`` is an int (called during training): each task keeps a
      best-so-far score and a task-specific checkpoint is saved next to
      ``args.chkpt_path`` whenever that score improves.
    * ``epoch`` is ``None`` (stand-alone evaluation): before each task the
      best checkpoint for that task is loaded from ``args.load_path`` and no
      checkpoints are saved.
    """

    def __init__(self, p):
        """Create an evaluator sharing state with the ``MLBase`` object ``p``.

        Args:
            p: an ``MLBase`` instance, forwarded as ``other`` to
                ``MLBase.__init__`` (presumably copies model/args/loaders from
                it -- confirm in ``train.mlbase``).
        """
        super().__init__(other=p)
        self.criterion = torch.nn.CrossEntropyLoss()
        # Best-so-far trackers; each one gates saving of a task checkpoint.
        self.best_acc = 0.
        self.min_fid = 1e3
        self.max_is = 0
        self.max_oodacu = 0.
        self.max_robust = 0.
        self.min_ece = 1e3
        self.log = Logger()
        # Nested dict: section name -> {metric name -> value}; see write_log.
        self.text_log = dict()
        self.robust_dl = get_dl_robust(self.args)

    def __call__(self,
                 epoch=None,
                 evals=('standard', 'ood', 'calri', 'landscape',
                        'fidis', 'robust')):
        """Run the selected evaluations for ``epoch`` and write the logs.

        Args:
            epoch: current training epoch, or ``None`` for a stand-alone
                evaluation. ``None`` counts as the last epoch; in that case
                the 'ood', 'calri' and 'robust' evaluations run regardless of
                ``evals`` ('standard' still has to be listed).
            evals: names of the evaluations to enable. Only 'standard', 'ood',
                'calri' and 'robust' are currently wired up; 'landscape' and
                'fidis' are accepted but not run from here.
        """
        self.model.eval()
        if self.args.head_eval_clip:
            self.model.head.set_eval()
        is_last_epoch = epoch is None or epoch == self.args.epochs - 1
        if is_last_epoch and epoch is not None:
            # End of training: evaluate the best-accuracy weights rather than
            # the weights of the final epoch.
            print("Loading best acc checkpoint")
            load_pretrained_weights(self.model, self._task_chkpt_path('acc'))

        if 'standard' in evals:
            self.eval_standard(epoch)
        # NOTE: eval_fidis / eval_landscape / eval_inex are intentionally not
        # invoked here at the moment.
        if epoch is None or ('ood' in evals and (
                (epoch + 1 >= self.args.auc_epoch and epoch % 5 == 0)
                or is_last_epoch)):
            self.eval_ood(epoch, is_last_epoch)
        if epoch is None or ('calri' in evals and is_last_epoch):
            self.eval_calri(epoch)
        if epoch is None or ('robust' in evals and is_last_epoch):
            self.eval_robust(epoch)

        # Written exactly once per call; on the last epoch this also emits the
        # summary table.
        self.write_log(is_last_epoch)
        if is_last_epoch:
            print(f"Evaluation finished at [{datetime.now():%Y.%m.%d %H:%M:%S}]")

    def eval_inex(self):
        """Run ``evaluate.inex.check_inex`` on the current model."""
        from evaluate.inex import check_inex
        check_inex(self.model)

    def eval_ood(self, epoch, is_last_epoch):
        """Run OOD-detection evaluation on ``self.vl_dl``.

        ``ana_oodauc`` returns the summary score used for checkpoint
        selection (higher is better). On the last epoch ``ana_logp_hist``,
        ``ana_entropy`` and ``ana_fpr95`` are run as well. The analysis
        helpers receive ``self.text_log`` and presumably record their metrics
        in it.

        Args:
            epoch: training epoch, or ``None`` for stand-alone evaluation
                (loads the best 'ood' checkpoint first, never saves).
            is_last_epoch: whether the extended end-of-run analyses run.
        """
        from evaluate.ood import ana_entropy, ana_fpr95
        print("OOD evaluation")
        if is_last_epoch and epoch is None:
            self.load_checkpoint('ood')
        total_score = ana_oodauc(self.model, self.vl_dl, self.text_log, self.args)
        if is_last_epoch:
            ana_logp_hist(self.model, self.vl_dl, self.args.output_path, self.args)
            ana_entropy(self.model, self.vl_dl, self.text_log, self.args)
            ana_fpr95(self.model, self.vl_dl, self.text_log, self.args)
        if self.max_oodacu < total_score and epoch is not None:
            self.max_oodacu = total_score
            self.save_checkpoint(epoch, 'ood', total_score)

    def eval_calri(self, epoch):
        """Run calibration evaluation and store it in ``text_log['calibration']``.

        ``calibration`` returns ``(ece, ece_h, acc_h)``. The raw ``ece``
        (lower is better) drives checkpoint selection. All three values are
        passed through ``get_valid_unit`` before being logged.

        Args:
            epoch: training epoch, or ``None`` for stand-alone evaluation
                (loads the best 'calri' checkpoint first, never saves).
        """
        print("Calibration evaluation")
        if epoch is None:
            self.load_checkpoint('calri')
        ece, ece_h, acc_h = calibration(self.model, self.vl_dl,
            self.args.output_path, self.text_log, args=self.args)
        if self.min_ece > ece and epoch is not None:
            self.min_ece = ece
            self.save_checkpoint(epoch, 'calri', ece)

        self.text_log['calibration'] = {
            'ece': get_valid_unit(ece),
            'ece_h': get_valid_unit(ece_h),
            'acc_h': get_valid_unit(acc_h),
        }

    def eval_landscape(self):
        """Plot the logit landscape using the ``args.landscape*`` settings."""
        print("Landscape evaluation")
        plot_logit_landscape(
            self.model,
            self.vl_dl,
            self.args.output_path,
            self.args.landscape,
            self.args.landscape_step,
            self.args.landscape_range,
            self.args.landscape_no_record)

    def eval_fidis(self, epoch):
        """Compute IS and FID for samples drawn from ``self.replay_buffer``.

        NOTE(review): ``self.replay_buffer`` is not created in this class;
        presumably provided by ``MLBase`` -- confirm before enabling.

        Args:
            epoch: training epoch, or ``None`` for stand-alone evaluation
                (loads the best 'is' / 'fid' checkpoints, never saves).
        """
        from evaluate.gen.eval_buffer import cond_is_fid
        print("FID/IS evaluation")
        if epoch is None:
            self.load_checkpoint('is')
        is_score, std, _ = cond_is_fid(
            self.model, self.replay_buffer, self.vl_dl, self.args, ratio=0.1, eval='is')
        if epoch is None:
            self.load_checkpoint('fid')
        _, _, fid_score = cond_is_fid(
            self.model, self.replay_buffer, self.vl_dl, self.args, ratio=0.9, eval='fid')

        if self.min_fid > fid_score and epoch is not None:
            self.min_fid = fid_score
            self.save_checkpoint(epoch, 'fid', fid_score)
        if self.max_is < is_score and epoch is not None:
            self.max_is = is_score
            self.save_checkpoint(epoch, 'is', is_score)
        print("fid: {}, is: {} ({})".format(fid_score, is_score, std))
        wandb.log({'gen/fid': fid_score,
                   'gen/is': is_score,
                   'gen/is_std': std}, commit=False)

    def eval_robust(self, epoch):
        """Evaluate adversarial robustness over a fixed list of attacks.

        Per-attack scores returned by ``PerturbAnalysis.print`` are stored
        under ``text_log['adv. acc']``. Their sum drives checkpoint
        selection (higher is better). ``ad_detect`` receives
        ``self.text_log`` and presumably records the detection metrics read
        by :meth:`write_log`.

        Args:
            epoch: training epoch, or ``None`` for stand-alone evaluation
                (loads the best 'robust' checkpoint first, never saves).
        """
        from evaluate.robust.pertb import PerturbAnalysis
        if epoch is None:
            self.load_checkpoint('robust')
        if self.args.cw:
            attacks = ['pgd-20', 'ifgsm-20', 'l2-20', 'cw', 'rand-20']
        else:
            attacks = ['pgd-20', 'ifgsm-20', 'l2-20', 'rand-20']

        pertb_anal = PerturbAnalysis(
            attacks=attacks, model=self.model, dloader=self.robust_dl,
            rootdir=self.args.output_path,
            args=self.args,
            overwrite=self.args.overwrite)
        pertb_anal.ad_detect(self.text_log)
        pertb_anal.plot()
        scores = pertb_anal.print()
        total = sum(scores)
        # Only save during training, consistent with the other evaluators;
        # in stand-alone mode (epoch is None) there is nothing to save into.
        if self.max_robust < total and epoch is not None:
            self.max_robust = total
            self.save_checkpoint(epoch, 'robust', total)

        self.text_log['adv. acc'] = {
            attack: get_valid_unit(score)
            for attack, score in zip(attacks, scores)
        }

    def eval_robust2(self):
        """Run a single configurable attack selected by ``args.attack``.

        Only 'pgd' is implemented. 'cw' raises ``NotImplementedError``, and
        any other value does nothing.

        Raises:
            NotImplementedError: if ``args.attack == 'cw'``.
        """
        from evaluate.robust.attack_pgd import attack_pgd
        from evaluate.robust.attack_cw import attack_cw
        print(f"Robustness evaluation {self.args.attack}")
        args = self.args
        attack = args.attack

        attack_params = {
            'epsilon' : args.epsilon,
            'seed' : args.random_seed
        }
        if attack == 'pgd':
            attack_params.update({
                'num_restarts' : args.num_restarts,
                'step_size' : args.step_size,
                'num_steps' : args.num_steps,
                'random_start' : args.random_start,
            })
            print(f'Running {attack_params}')
            attack_pgd(args, self.model, self.vl_dl, attack_params)
        elif attack == 'cw':
            # TODO: wire up ``attack_cw`` with binary_search_steps,
            # max_iterations, learning_rate, initial_const and
            # tau_decrease_factor from args.
            raise NotImplementedError("CW attack is not implemented in eval_robust2")

    @torch.no_grad()
    def eval_standard(self, epoch):
        """Compute top-1 accuracy and loss on ``self.vl_dl``.

        Logs to stdout and wandb, updates ``self.best_acc``, saves the 'acc'
        checkpoint on improvement (training mode only), and stores the
        accuracy under ``text_log['stanard acc']['acc']``.

        Args:
            epoch: training epoch, or ``None`` for stand-alone evaluation
                (loads the best 'acc' checkpoint first, never saves).
        """
        print("Standard evaluation")
        if epoch is None:
            self.load_checkpoint('acc')
        self.log.reset()
        for x, y in self.vl_dl:
            x = x.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            bsz = y.shape[0]

            logits = self.model(x)
            loss = self.criterion(logits, y)
            acc1, _ = self.accuracy(logits, y, topk=(1, 5))
            self.log.update(bsz, acc1=acc1, loss=loss)

        if self.best_acc < self.log.acc1.avg and epoch is not None:
            self.save_checkpoint(epoch, 'acc', self.log.acc1.avg)

        self.best_acc = max(self.best_acc, self.log.acc1.avg)
        d, s = self.log.out()
        print('Test: [{0}]\t'.format(epoch) + s +
              f" best acc: {self.best_acc:.2f}")
        wandb.log({**{'info/epoch': epoch,
                      'info/lr': self.optimizer.param_groups[0]["lr"],
                      'test/best_acc': self.best_acc}, **d}, commit=True)
        # NOTE: the key is misspelled ('stanard') but kept as-is because
        # write_log and the on-disk logs key on this exact string.
        self.text_log['stanard acc'] = {'acc': get_valid_unit(self.log.acc1.avg)}

    @torch.no_grad()
    def accuracy(self, output, target, topk=(1,)):
        """Return top-k accuracy (in percent) for each ``k`` in ``topk``.

        ``output`` holds one row of class scores per sample (top-k is taken
        over dim 1); ``target`` holds the true class index per sample.
        """
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

    def _task_chkpt_path(self, task):
        """Return the per-task checkpoint path derived from ``args.chkpt_path``."""
        # Drops the last 4 characters, i.e. assumes chkpt_path ends in '.pth'.
        return self.args.chkpt_path[:-4] + f'_{task}.pth'

    def load_checkpoint(self, task):
        """Load ``<load_path>/chkpt_final_<task>.pth`` into ``self.model``.

        Falls back to ``<load_path>/chkpt_final_acc.pth`` when the
        task-specific checkpoint does not exist.
        """
        path = os.path.join(self.args.load_path, f'chkpt_final_{task}.pth')
        if os.path.isfile(path):
            print(f"Loading best {task} checkpoint")
            load_pretrained_weights(self.model, path)
        else:
            print(f"Cannot Find best {task} checkpoint")
            print(f"Load best acc checkpoint")
            load_pretrained_weights(
                self.model, os.path.join(self.args.load_path, 'chkpt_final_acc.pth'))

    def save_checkpoint(self, epoch, task, score):
        """Save model/optimizer state, replay buffer and ``score`` for ``task``.

        Uses ``save_on_master`` and writes to the path built by
        ``_task_chkpt_path(task)``.
        """
        save_dict = {"epoch": epoch, task: score,
                     "model_dict": self.model.state_dict(),
                     "optimizer": self.optimizer.state_dict(),
                     "replay_buffer" : self.replay_buffer}
        save_on_master(save_dict, self._task_chkpt_path(task))

    def write_log(self, is_last_epoch):
        """Dump ``self.text_log`` to text files under ``args.output_path``.

        ``log_all.txt`` is always rewritten. Each section is written as its
        name, a line of comma-terminated metric names, then one value per
        line. On the last epoch a fixed-order summary is also printed
        (rounded to 2 decimals) and written at full precision to
        ``log_table.txt``. It raises ``KeyError`` if a required metric was
        never recorded.
        """
        with open(os.path.join(self.args.output_path, 'log_all.txt'), 'w') as fh:
            for section, entries in self.text_log.items():
                fh.write(section + '\n')
                fh.write(''.join(name + ', ' for name in entries) + '\n')
                for value in entries.values():
                    fh.write(str(value) + '\n')
                fh.write('\n')

        if not is_last_epoch:
            return

        ds_print_seq = ['cifar10', 'svhn', 'dtd', 'iSUN', 'LSUN',
                        'places365', 'LSUN_R', 'cifar100', 'interp',
                        'N', 'U', 'OODomain', 'Constant']
        score_fn_types = ['p_x', 'p_y|x']
        adv_lst = ['pgd-20', 'ifgsm-20', 'l2-20', 'rand-20']
        metrics = ['fpr95', 'ood_auc', 'ood_aupr']
        adv_metrics = ['adv. auc', 'adv. aupr']

        with open(os.path.join(self.args.output_path, 'log_table.txt'), 'w') as fh:
            def emit(value):
                # Console shows a rounded value; the file keeps full precision.
                print("{:.2f}".format(value))
                fh.write(str(value) + '\n')

            for score_fn_type in score_fn_types:
                print(score_fn_type)
                fh.write(score_fn_type + '\n')
                emit(self.text_log['stanard acc']['acc'])
                for ds in ds_print_seq:
                    for metric in metrics:
                        emit(self.text_log[metric][ds + '_' + score_fn_type])
                for adv in adv_lst:
                    for adv_metric in adv_metrics:
                        emit(self.text_log[adv_metric][adv + '_' + score_fn_type])
                    emit(self.text_log['adv. acc'][adv])
                print()
                fh.write('\n')

if __name__ == '__main__':
    # Stand-alone evaluation entry point: parse the CLI, switch to eval mode,
    # start wandb, then run the evaluator with epoch=None so each task loads
    # its best checkpoint from args.load_path.
    args = get_general_args()
    args.eval = True
    init_wandb(args)
    base = MLBase(args)
    evaluator = Evaluator(base)
    evaluator()
